In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from sklearn.cluster import KMeans
In [2]:
# Load the artwork table; later cells read its Artwork_Image_Path column.
# The CSV path is relative, so it is resolved against the kernel's working directory.
df = pd.read_csv('artwork_img_kmean.csv')
In [3]:
# Draw 16 image paths to inspect. The fixed random_state makes the draw
# repeatable, so the recorded output and the figures below stay the same
# across kernel restarts (re-run this cell once to refresh the stored output).
sample_images = df["Artwork_Image_Path"].sample(16, random_state=0).tolist()
sample_images
Out[3]:
['images2/sothebys/N08417/N08417-56-lr-1.jpg',
 'images2/asi2-102571/5.jpg',
 'images2/asi2-125565/652.jpg',
 'images2/asi2-117337/45.jpg',
 'images2/asi2-109125/66.jpg',
 'images2/asi2-89270/45.jpg',
 'images2/asi2-118745/5.jpg',
 'images2/sothebys/N08300/N08300-67-lr-1.jpg',
 'images2/asi2-108907/129.jpg',
 'images2/asi2-105601/32.jpg',
 'images2/asi2-83137/144.jpg',
 'images2/missingImages/0538057/112.jpg',
 'images2/SY.NY120040324/00182.jpg',
 'images2/asi2-105798/56.jpg',
 'images2/asi2-134206/235.jpg',
 'images2/0611441/477.JPG']
In [4]:
# Preview the sampled artworks in a 4x4 grid, one image per panel, with the
# axis decorations hidden.
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(15, 15))
for idx, panel in enumerate(axes.flat):
    picture = plt.imread(sample_images[idx])
    panel.imshow(picture)
    panel.axis("off")
In [5]:
def find_color_palettes(imgpath, n_colors=5):
    """Show an image above its dominant colours, found with k-means.

    The image is read with OpenCV and converted from BGR to RGB, every pixel
    is treated as one sample, and the samples are clustered with ``KMeans``.
    The rounded cluster centres are drawn as swatches underneath the image,
    ordered from the most populated cluster to the least populated one.

    Parameters
    ----------
    imgpath : str
        Path of the image file to analyse.
    n_colors : int, optional
        Number of clusters, and therefore swatches, to extract. Defaults to 5.

    Raises
    ------
    IOError
        If OpenCV cannot read ``imgpath``.
    AssertionError
        If k-means used up all of its allowed iterations.
    """
    bgr = cv2.imread(imgpath)
    if bgr is None:
        # cv2.imread reports a missing/undecodable file by returning None
        # instead of raising, which would otherwise fail later in cvtColor.
        raise IOError("could not read image: {}".format(imgpath))
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # One row per pixel, one column per colour channel.
    nrows, ncols, nchns = img.shape
    X = img.reshape(nrows*ncols, nchns)
    km = KMeans(n_clusters=n_colors, random_state=0).fit(X)
    assert km.n_iter_ < km.max_iter, "did not converge: iter(n={} max={})".format(km.n_iter_, km.max_iter)

    # Rank the clusters by how many pixels were assigned to each of them.
    labels, counts = np.unique(km.labels_, return_counts=True)
    sorted_by_count = sorted(zip(labels, counts), key=lambda x: x[1], reverse=True)
    palettes_by_count = [km.cluster_centers_[label].round(0).astype(int).tolist()
                         for label, count in sorted_by_count]

    # Start from an empty figure: the axes are created by subplot2grid below,
    # so pre-creating subplots would only leave unused axes underneath them.
    plt.figure(figsize=(10, 10))
    grid = (2, n_colors)
    image_ax = plt.subplot2grid(grid, (0, 0), colspan=n_colors)
    swatch_axes = [plt.subplot2grid(grid, (1, col)) for col in range(n_colors)]

    # Explicit loop: a bare map() is lazy on Python 3 and would never run.
    for ax in [image_ax] + swatch_axes:
        ax.axis("off")

    image_ax.imshow(img)
    for ax, color in zip(swatch_axes, palettes_by_count):
        # A 4x4 block of one RGB triple renders as a solid colour swatch.
        ax.imshow([[color]*4]*4)
    plt.tight_layout()
In [6]:
# Show the image and its dominant-colour swatches for every sampled artwork;
# each call draws its own figure.
for imgpath in sample_images:
    find_color_palettes(imgpath)